
import torch 
import os 
import matplotlib.pyplot as plt 
from utils.model2 import YoloModel
import torchvision.transforms as transforms 
import cv2 
from utils.transforms import Resize, DEFAULT_TRANSFORMS
import numpy as np 
from utils.nms import rescale_boxes, non_max_suppression 
import matplotlib.pyplot as plt 
def _load_class_names(path):
    """Read a class-names file into a ``{line_index: name}`` dict.

    Each line is stripped of surrounding whitespace; line ``i`` (0-based)
    becomes the display name for class index ``i``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return {i: line.strip() for i, line in enumerate(f)}


def _point_in_image(x, y, width, height):
    """Return True if ``(x, y)`` lies inside a ``width`` x ``height`` image.

    The lower bound is inclusive: coordinate 0 is a valid pixel, so boxes
    touching the top/left edge are kept.
    """
    return 0 <= x < width and 0 <= y < height


def _draw_detections(image, detections, class_names):
    """Draw a box and a "Conf:<score>,<name>" label for each in-bounds detection.

    Args:
        image: Image array to annotate; modified in place.
        detections: Iterable of ``(x1, y1, x2, y2, conf, cls_pred)`` rows in
            ``image`` pixel coordinates.
        class_names: Mapping from integer class index to display name; an
            index missing from the mapping is labelled with the index itself.

    Detections whose top-left or bottom-right corner falls outside the image
    are not drawn; their coordinates are printed instead.
    """
    height, width = image.shape[:2]
    for x1, y1, x2, y2, conf, cls_pred in detections:
        if not (
            _point_in_image(x1, y1, width, height)
            and _point_in_image(x2, y2, width, height)
        ):
            print(x1, y1, x2, y2)
            continue
        print("OUTPUT", height, width, x1, y1, x2, y2, image.dtype, image.shape)
        cls_idx = int(cls_pred)
        label = class_names.get(cls_idx, str(cls_idx))
        top_left = (int(x1), int(y1))
        cv2.rectangle(image, top_left, (int(x2), int(y2)), (0, 255, 0))
        cv2.putText(
            image,
            f"Conf:{conf:.2f},{label}",
            top_left,
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            (255, 0, 0),
        )


def main(
    ckpt_path="ckpt/23.pt",
    image_path="data/input.jpg",
    names_path="ckpt/coco.names",
    output_path="out2.jpg",
    img_size=416,
    conf_thres=0.1,
    nms_thres=0.1,
    device="cpu",
    show=True,
):
    """Run YoloModel on one image, draw the detections, then save and optionally show it.

    Every argument defaults to the value that used to be hard-coded, so a bare
    ``main()`` call behaves as before.

    Args:
        ckpt_path: State-dict checkpoint loaded into ``YoloModel``.
        image_path: Input image, read with ``cv2.imread``.
        names_path: Class-names file with one name per line.
        output_path: Destination of the annotated image.
        img_size: Size passed to ``Resize`` for the network input. It is also
            passed to ``rescale_boxes`` to map boxes back to the original image.
        conf_thres, nms_thres: Passed positionally as the 2nd and 3rd
            arguments of ``non_max_suppression``. These are presumably the
            confidence and IoU thresholds; confirm against ``utils.nms``.
        device: Torch device (name or ``torch.device``) for the model and input.
        show: If True, display the result in a window and wait for a key press.

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read.
    """
    device = torch.device(device)
    model = YoloModel()
    model.eval()
    model.to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail loudly with the offending path.
        raise FileNotFoundError(f"could not read image: {image_path}")
    print(image.shape, image.dtype)

    # NOTE(review): cv2.imread yields BGR channel order. If the model was trained
    # on RGB images, convert first with cv2.cvtColor(image, cv2.COLOR_BGR2RGB).
    # Confirm against the training data pipeline.
    # The transforms take an (image, boxes) pair. A dummy all-zero box array is
    # supplied because there are no ground-truth boxes at inference time.
    preprocess = transforms.Compose([DEFAULT_TRANSFORMS, Resize(img_size)])
    input_img = preprocess((image, np.zeros((1, 5))))[0].unsqueeze(0).to(device)

    class_names = _load_class_names(names_path)

    with torch.no_grad():
        print(input_img.max(), input_img.min(), input_img.shape)
        # The model returns a sequence of prediction tensors (presumably one per
        # detection scale); concatenate them along dim 1 before NMS.
        predictions = torch.cat(model(input_img), 1)
        detections = non_max_suppression(predictions, conf_thres, nms_thres)[0]
        if detections is None or len(detections) == 0:
            # Nothing survived NMS; skip rescaling, which would fail on None.
            rows = []
        else:
            # .cpu() so .numpy() also works when running on a GPU device.
            rows = rescale_boxes(detections, img_size, image.shape[:2]).cpu().numpy()

    _draw_detections(image, rows, class_names)
    cv2.imwrite(output_path, image)
    if show:
        cv2.imshow("www", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
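# Training is a separate script. Example command to run it in the background,
# logging to ckpt/yolo.log:
# nohup /home/yuzy/software/anaconda39/bin/python yolo.train.py > ckpt/yolo.log 2>&1 &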
if __name__ == "__main__":
    # Run the inference demo only when executed as a script, not on import.
    main()
